import matplotlib.pyplot as plt
from mpl_toolkits.mplot3d import Axes3D
import numpy as np

def plot_dynamic_figure(trace, fun, *, xlim=(0, 5), ylim=(-3, 3),
                        zlim=(-10, 100), interval=0.5):
    """Animate a sequence of point sets over the surface ``z = fun(x1, x2)``.

    The surface is drawn once inside a fixed 3-D box. For every entry of
    *trace*, the previous frame's points are replaced by the new ones and the
    figure is refreshed via ``plt.pause(interval)``. After the last frame the
    window is left open with ``plt.show()``.

    Args:
        trace: Iterable of frames. Each frame is converted with
            ``np.asarray`` and must be 2-D with at least four columns. Each
            row is plotted at ``(col0, col1 - col2, -col3)``.
        fun: Callable evaluated element-wise on two meshgrid arrays
            ``fun(x1, x2)``; its result gives the surface heights.
        xlim, ylim: Fixed x/y axis limits. They also bound the grid on which
            *fun* is sampled.
        zlim: Fixed z axis limits.
        interval: Seconds passed to ``plt.pause`` after each frame.

    Raises:
        IndexError: If a frame has fewer than four columns.
    """
    fig = plt.figure(figsize=(16, 9))
    # add_subplot(projection="3d") replaces Axes3D(fig, auto_add_to_figure=
    # False) + fig.add_axes(); that keyword was removed in Matplotlib 3.6.
    ax = fig.add_subplot(111, projection="3d")

    grid_x1, grid_x2 = np.meshgrid(np.linspace(xlim[0], xlim[1], 60),
                                   np.linspace(ylim[0], ylim[1], 110))
    heights = fun(grid_x1, grid_x2)

    # The surface never changes, so draw it once instead of on every frame.
    ax.plot_surface(grid_x1, grid_x2, heights,
                    cstride=20, rstride=20, color='#826677', alpha=0.5)

    # Set the limits explicitly; this also turns autoscaling off, so later
    # scatter calls cannot move them. (Previously a per-frame plt.cla()
    # wiped these limits before anything was drawn.)
    ax.set_xlim3d(xlim)
    ax.set_ylim3d(ylim)
    ax.set_zlim3d(zlim)

    points = None
    for frame in trace:
        frame = np.asarray(frame)
        # Remove only the previous frame's points; the surface and limits stay.
        if points is not None:
            points.remove()
        # Column mapping kept from the original code. NOTE(review): presumably
        # y is stored split across columns 1 and 2 and column 3 holds a
        # negated value; confirm against whatever produces *trace*.
        points = ax.scatter(frame[:, 0], frame[:, 1] - frame[:, 2],
                            -frame[:, 3], c="black", s=80)
        plt.pause(interval)

    plt.show()

